How the PyTorch Lightning Trainer works
https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#under-the-hood
#PyTorch_Lightning
code: python
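 import torch
 # Assumes model (a LightningModule whose training_step returns a loss tensor),
 # optimizer and train_dataloader already exist.
 # put model in train mode
 model.train()
 torch.set_grad_enabled(True)
 losses = []
 for batch_idx, batch in enumerate(train_dataloader):
     # calls hooks like this one
     model.on_train_batch_start(batch, batch_idx)
     # train step
     loss = model.training_step(batch, batch_idx)
     # clear gradients
     optimizer.zero_grad()
     # backward
     loss.backward()
     # update parameters
     optimizer.step()
     # detach so the list does not keep every step's autograd graph alive
     losses.append(loss.detach())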
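The loop above can be traced without installing anything. Below is a toy, framework-free mimic of it, assuming a one-weight model y = w * x trained with plain SGD. `ToyModule`, `ToySGD`, `backward` and `fit_one_epoch` are hypothetical names, and a hand-written gradient stands in for autograd; this is not Lightning's actual implementation.

```python
# Toy, framework-free mimic of the Trainer loop above (NOT Lightning's real code).
# A hand-written gradient for the 1-parameter model y = w * x replaces autograd.

class ToyModule:
    """Hypothetical stand-in for a LightningModule with a single weight w."""

    def __init__(self):
        self.w = 0.0
        self.grad = 0.0
        self.training = False
        self.calls = []  # records hook order so it can be inspected

    def train(self):
        self.training = True

    def on_train_batch_start(self, batch, batch_idx):
        self.calls.append(("on_train_batch_start", batch_idx))

    def training_step(self, batch, batch_idx):
        x, y = batch
        err = self.w * x - y
        self._dloss_dw = 2 * err * x  # d(err^2)/dw, consumed by backward()
        self.calls.append(("training_step", batch_idx))
        return err * err  # squared-error loss


class ToySGD:
    """Hypothetical optimizer mimicking zero_grad() / step()."""

    def __init__(self, module, lr):
        self.module = module
        self.lr = lr

    def zero_grad(self):
        self.module.grad = 0.0

    def step(self):
        self.module.w -= self.lr * self.module.grad


def backward(module):
    # Stands in for loss.backward(): accumulate the gradient into .grad
    module.grad += module._dloss_dw


def fit_one_epoch(model, optimizer, train_dataloader):
    model.train()
    losses = []
    for batch_idx, batch in enumerate(train_dataloader):
        model.on_train_batch_start(batch, batch_idx)    # hook
        loss = model.training_step(batch, batch_idx)   # train step
        optimizer.zero_grad()                          # clear gradients
        backward(model)                                # backward
        optimizer.step()                               # update parameters
        losses.append(loss)
    return losses


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples of y = 2x
model = ToyModule()
opt = ToySGD(model, lr=0.05)
for _ in range(50):
    losses = fit_one_epoch(model, opt, data)
print(round(model.w, 3))  # → 2.0
```

The Trainer's job is essentially this loop: the user supplies `training_step` and any hooks, and the framework decides when to call them relative to `zero_grad`, `backward` and `step`.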
More detail: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks
This page also serves as a list of the LightningModule hooks.
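pl.LightningModule's __call__ is the __call__ defined on nn.Module.
Loosely speaking, it calls the forward method (nn.Module.__call__ also runs any registered forward hooks around forward).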
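A rough sketch of what "calls forward" means, assuming a heavily simplified stand-in for nn.Module. `MiniModule` is hypothetical and models only forward pre-hooks and forward hooks; the real `nn.Module.__call__` also handles backward hooks, global hooks and keyword-argument hooks, and its `register_*` methods return removable handles.

```python
class MiniModule:
    """Hypothetical, greatly simplified stand-in for torch.nn.Module's call path."""

    def __init__(self):
        self._forward_pre_hooks = []  # called as hook(module, args) before forward
        self._forward_hooks = []      # called as hook(module, args, output) after forward

    def register_forward_pre_hook(self, hook):
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        for hook in self._forward_pre_hooks:
            result = hook(self, args)
            if result is not None:  # a pre-hook may replace the positional inputs
                args = result if isinstance(result, tuple) else (result,)
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:  # a forward hook may replace the output
                output = result
        return output


class Double(MiniModule):
    def forward(self, x):
        return 2 * x


m = Double()
print(m(3))          # → 6
m.register_forward_hook(lambda mod, args, out: out + 1)
print(m(3))          # → 7 (the hook changed the output)
print(m.forward(3))  # → 6 (calling forward directly skips the hooks)
```

This is why `model(x)` is preferred over `model.forward(x)`: only the former runs the registered hooks.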